{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2019 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Music Recommendation using AutoML Tables\n",
    "\n",
    "## Overview\n",
    "In this notebook we will see how [AutoML Tables](https://cloud.google.com/automl-tables/) can be used to make music recommendations to users. AutoML Tables is a supervised learning service for structured data that can vastly simplify the model building process.\n",
    "\n",
    "### Dataset\n",
    "AutoML Tables allows data to be imported from either GCS or BigQuery. This tutorial uses the [ListenBrainz](https://console.cloud.google.com/marketplace/details/metabrainz/listenbrainz) dataset from [Cloud Marketplace](https://console.cloud.google.com/marketplace), hosted in BigQuery.\n",
    "\n",
    "The ListenBrainz dataset is a log of songs played by users. Notable pieces of the schema include:\n",
    "  - **user_name:** a user id.\n",
    "  - **track_name:** a song id.\n",
    "  - **artist_name:** the artist of the song.\n",
    "  - **release_name:** the album of the song.\n",
    "  - **tags:** the genres of the song.\n",
    "\n",
    "### Objective\n",
    "The goal of this notebook is to demonstrate how to create a lookup table in BigQuery of songs to recommend to users using a log of user-song listens and AutoML Tables. This will be done by training a binary classification model to predict whether or not a `user` will like a given `song`. In the training data, liking a song is defined as having listened to it more than twice. **The predictions for every `(user, song)` pair are then used to generate a ranking of the most similar songs for each user.**\n",
    "\n",
    "As the number of `(user, song)` pairs is the product of the number of unique users and the number of unique songs, this approach may not be optimal for extremely large datasets. One workaround would be to train a model that learns to embed users and songs in the same embedding space, and use a nearest-neighbors algorithm to get recommendations for users. Unfortunately, AutoML Tables does not expose any feature for training and using embeddings, so a [custom ML model](https://github.com/GoogleCloudPlatform/professional-services/tree/master/examples/cloudml-collaborative-filtering) would need to be used instead.\n",
    "\n",
    "Another recommendation approach that is worth mentioning is [using extreme multiclass classification](https://ai.google/research/pubs/pub45530), as that also circumvents storing every possible pair of users and songs. Unfortunately, AutoML Tables does not support multiclass classification with more than [100 classes](https://cloud.google.com/automl-tables/docs/prepare#target-requirements).\n",
    "\n",
    "### Costs\n",
    "This tutorial uses billable components of Google Cloud Platform (GCP):\n",
    "- Cloud AutoML Tables\n",
    "\n",
    "Learn about [AutoML Tables pricing](https://cloud.google.com/automl-tables/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow the [AutoML Tables documentation](https://cloud.google.com/automl-tables/docs/) to\n",
    "* [Enable billing](https://cloud.google.com/billing/docs/how-to/modify-project).\n",
    "* [Enable AutoML API](https://console.cloud.google.com/apis/library/automl.googleapis.com?q=automl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 PIP Install Packages and dependencies\n",
    "Install additional dependencies not installed in the notebook environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip3 install --upgrade --quiet google-cloud-automl google-cloud-bigquery"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Restart the kernel__ to allow `automl_v1beta1` to be imported. The following cell should succeed after a kernel restart:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.cloud import automl_v1beta1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Import libraries and define constants"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Populate the following cell with the necessary constants and run it to initialize them. The assertion at the end fails if any constant is left empty or zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The GCP project id.\n",
    "PROJECT_ID = \"\"\n",
    "# The region to use for compute resources (AutoML isn't supported in some regions).\n",
    "LOCATION = \"us-central1\"\n",
    "# A name for the AutoML Tables dataset to create.\n",
    "DATASET_DISPLAY_NAME = \"\"\n",
    "# The BigQuery dataset to import data from (doesn't need to exist).\n",
    "INPUT_BQ_DATASET = \"\"\n",
    "# The BigQuery table to import data from (doesn't need to exist).\n",
    "INPUT_BQ_TABLE = \"\"\n",
    "# A name for the AutoML Tables model to create.\n",
    "MODEL_DISPLAY_NAME = \"\"\n",
    "# The number of hours to train the model.\n",
    "MODEL_TRAIN_HOURS = 0\n",
    "\n",
    "assert all([\n",
    "    PROJECT_ID,\n",
    "    LOCATION,\n",
    "    DATASET_DISPLAY_NAME,\n",
    "    INPUT_BQ_DATASET,\n",
    "    INPUT_BQ_TABLE,\n",
    "    MODEL_DISPLAY_NAME,\n",
    "    MODEL_TRAIN_HOURS,\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import relevant packages and initialize clients for BigQuery and AutoML Tables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "from google.cloud import automl_v1beta1\n",
    "from google.cloud import bigquery\n",
    "from google.cloud import exceptions\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "tables_client = automl_v1beta1.TablesClient(project=PROJECT_ID, region=LOCATION)\n",
    "# Use the same project for BigQuery as for AutoML Tables.\n",
    "bq_client = bigquery.Client(project=PROJECT_ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create a Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to train a model, a structured dataset must be ingested into AutoML Tables from either BigQuery or Google Cloud Storage. Once ingested, the user can select columns to use as features, labels, or weights and configure the loss function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Create BigQuery table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, do some feature engineering on the original ListenBrainz dataset to turn it into a dataset for training and export it into a separate BigQuery table:\n",
    "\n",
    "1. Make each sample a unique `(user, song)` pair.\n",
    "2. For features, use the user's distribution over the 20 most popular tags (`user_tags0` to `user_tags19`) and the song's artist, tags (genres), and number of albums.\n",
    "3. For a label, use 1 if the user has listened to the song more than twice and 0 otherwise.\n",
    "4. Add a weight based on the listen count. For positive samples this is the square root of the listen count divided by the maximum number of times that user has listened to any song, so that active users don't have a disproportionate effect on the model error. For negative samples it is 0.5 divided by the listen count. This is to help account for the skew in the label distribution."
   ]
  },
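  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration, the labelling and weighting rule at the end of the query in the next cell can be mirrored in plain Python. This is only a sketch: the helper name `label_and_weight` is hypothetical, the listen counts are made up, and the real computation happens in BigQuery.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def label_and_weight(user_song_listens, user_max_listen):\n",
    "    # Hypothetical helper mirroring the IF(...) expressions for\n",
    "    # `label` and `weight` in the query.\n",
    "    if user_song_listens > 2:\n",
    "        # Positive sample: weight grows with the normalized listen count.\n",
    "        return 1, math.sqrt(user_song_listens / user_max_listen)\n",
    "    # Negative sample: a single listen weighs more than two listens.\n",
    "    return 0, 0.5 / user_song_listens\n",
    "\n",
    "\n",
    "print(label_and_weight(8, 32))  # (1, 0.5)\n",
    "print(label_and_weight(1, 32))  # (0, 0.5)\n",
    "print(label_and_weight(2, 32))  # (0, 0.25)\n",
    "```\n",
    "\n",
    "The threshold of more than two listens matches the definition of liking a song given in the Objective section."
   ]
  },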
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"\"\"\n",
    "   WITH\n",
    "    songs AS (\n",
    "      SELECT CONCAT(track_name, \" by \", artist_name) AS song,\n",
    "        MAX(tags) as tags\n",
    "      FROM `listenbrainz.listenbrainz.listen`\n",
    "      GROUP BY song\n",
    "      HAVING tags != \"\"\n",
    "      ORDER BY COUNT(*) DESC\n",
    "      LIMIT 10000\n",
    "    ),\n",
    "    user_songs AS (\n",
    "      SELECT user_name AS user, ANY_VALUE(artist_name) AS artist,\n",
    "        CONCAT(track_name, \" by \", artist_name) AS song,\n",
    "        SPLIT(ANY_VALUE(songs.tags), \",\") AS tags,\n",
    "        COUNT(*) AS user_song_listens\n",
    "      FROM `listenbrainz.listenbrainz.listen`\n",
    "      JOIN songs ON songs.song = CONCAT(track_name, \" by \", artist_name)\n",
    "      GROUP BY user_name, song\n",
    "    ),\n",
    "    user_tags AS (\n",
    "      SELECT user, tag, COUNT(*) AS COUNT\n",
    "      FROM user_songs,\n",
    "      UNNEST(tags) tag\n",
    "      WHERE tag != \"\"\n",
    "      GROUP BY user, tag\n",
    "    ),\n",
    "    top_tags AS (\n",
    "      SELECT tag\n",
    "      FROM user_tags\n",
    "      GROUP BY tag\n",
    "      ORDER BY SUM(count) DESC\n",
    "      LIMIT 20\n",
    "    ),\n",
    "    tag_table AS (\n",
    "      SELECT user, b.tag\n",
    "      FROM user_tags a, top_tags b\n",
    "      GROUP BY user, b.tag\n",
    "    ),\n",
    "    user_tag_features AS (\n",
    "      SELECT user,\n",
    "        ARRAY_AGG(IFNULL(count, 0) ORDER BY tag) as user_tags,\n",
    "        SUM(count) as tag_count\n",
    "      FROM tag_table\n",
    "      LEFT JOIN user_tags USING (user, tag)\n",
    "      GROUP BY user\n",
    "    ), user_features AS (\n",
    "      SELECT user, MAX(user_song_listens) AS user_max_listen,\n",
    "        ANY_VALUE(user_tags)[OFFSET(0)]/ANY_VALUE(tag_count) as user_tags0,\n",
    "        ANY_VALUE(user_tags)[OFFSET(1)]/ANY_VALUE(tag_count) as user_tags1,\n",
    "        ANY_VALUE(user_tags)[OFFSET(2)]/ANY_VALUE(tag_count) as user_tags2,\n",
    "        ANY_VALUE(user_tags)[OFFSET(3)]/ANY_VALUE(tag_count) as user_tags3,\n",
    "        ANY_VALUE(user_tags)[OFFSET(4)]/ANY_VALUE(tag_count) as user_tags4,\n",
    "        ANY_VALUE(user_tags)[OFFSET(5)]/ANY_VALUE(tag_count) as user_tags5,\n",
    "        ANY_VALUE(user_tags)[OFFSET(6)]/ANY_VALUE(tag_count) as user_tags6,\n",
    "        ANY_VALUE(user_tags)[OFFSET(7)]/ANY_VALUE(tag_count) as user_tags7,\n",
    "        ANY_VALUE(user_tags)[OFFSET(8)]/ANY_VALUE(tag_count) as user_tags8,\n",
    "        ANY_VALUE(user_tags)[OFFSET(9)]/ANY_VALUE(tag_count) as user_tags9,\n",
    "        ANY_VALUE(user_tags)[OFFSET(10)]/ANY_VALUE(tag_count) as user_tags10,\n",
    "        ANY_VALUE(user_tags)[OFFSET(11)]/ANY_VALUE(tag_count) as user_tags11,\n",
    "        ANY_VALUE(user_tags)[OFFSET(12)]/ANY_VALUE(tag_count) as user_tags12,\n",
    "        ANY_VALUE(user_tags)[OFFSET(13)]/ANY_VALUE(tag_count) as user_tags13,\n",
    "        ANY_VALUE(user_tags)[OFFSET(14)]/ANY_VALUE(tag_count) as user_tags14,\n",
    "        ANY_VALUE(user_tags)[OFFSET(15)]/ANY_VALUE(tag_count) as user_tags15,\n",
    "        ANY_VALUE(user_tags)[OFFSET(16)]/ANY_VALUE(tag_count) as user_tags16,\n",
    "        ANY_VALUE(user_tags)[OFFSET(17)]/ANY_VALUE(tag_count) as user_tags17,\n",
    "        ANY_VALUE(user_tags)[OFFSET(18)]/ANY_VALUE(tag_count) as user_tags18,\n",
    "        ANY_VALUE(user_tags)[OFFSET(19)]/ANY_VALUE(tag_count) as user_tags19\n",
    "      FROM user_songs\n",
    "      LEFT JOIN user_tag_features USING (user)\n",
    "      GROUP BY user\n",
    "      HAVING COUNT(*) < 5000 AND user_max_listen > 2\n",
    "    ),\n",
    "    item_features AS (\n",
    "      SELECT CONCAT(track_name, \" by \", artist_name) AS song,\n",
    "        COUNT(DISTINCT(release_name)) AS albums\n",
    "      FROM `listenbrainz.listenbrainz.listen`\n",
    "      WHERE track_name != \"\"\n",
    "      GROUP BY song\n",
    "    )\n",
    "  SELECT user, song, artist, tags, albums,\n",
    "    user_tags0,\n",
    "    user_tags1,\n",
    "    user_tags2,\n",
    "    user_tags3,\n",
    "    user_tags4,\n",
    "    user_tags5,\n",
    "    user_tags6,\n",
    "    user_tags7,\n",
    "    user_tags8,\n",
    "    user_tags9,\n",
    "    user_tags10,\n",
    "    user_tags11,\n",
    "    user_tags12,\n",
    "    user_tags13,\n",
    "    user_tags14,\n",
    "    user_tags15,\n",
    "    user_tags16,\n",
    "    user_tags17,\n",
    "    user_tags18,\n",
    "    user_tags19,\n",
    "    IF(user_song_listens > 2, \n",
    "       SQRT(user_song_listens/user_max_listen), \n",
    "       .5/user_song_listens) AS weight,\n",
    "    IF(user_song_listens > 2, 1, 0) as label\n",
    "  FROM user_songs\n",
    "  JOIN user_features USING(user)\n",
    "  JOIN item_features USING(song)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_table_from_query(query, table):\n",
    "    \"\"\"Creates a new table using the results from the given query.\n",
    "    \n",
    "    Args:\n",
    "        query: a query string.\n",
    "        table: a name to give the new table.\n",
    "    \"\"\"\n",
    "    job_config = bigquery.QueryJobConfig()\n",
    "    bq_dataset = bigquery.Dataset(\"{0}.{1}\".format(PROJECT_ID, INPUT_BQ_DATASET))\n",
    "    bq_dataset.location = \"US\"\n",
    "\n",
    "    try:\n",
    "        bq_dataset = bq_client.create_dataset(bq_dataset)\n",
    "    except exceptions.Conflict:\n",
    "        pass\n",
    "\n",
    "    table_ref = bq_client.dataset(INPUT_BQ_DATASET).table(table)\n",
    "    job_config.destination = table_ref\n",
    "\n",
    "    query_job = bq_client.query(query,\n",
    "                             location=bq_dataset.location,\n",
    "                             job_config=job_config)\n",
    "\n",
    "    query_job.result()\n",
    "    print('Query results loaded to table {}'.format(table_ref.path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_table_from_query(query, INPUT_BQ_TABLE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Create AutoML Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a Dataset by importing the BigQuery table that was just created. Importing data may take a few minutes or hours depending on the size of your data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tables_client.create_dataset(\n",
    "    dataset_display_name=DATASET_DISPLAY_NAME)\n",
    "\n",
    "dataset_bq_input_uri = 'bq://{0}.{1}.{2}'.format(\n",
    "    PROJECT_ID, INPUT_BQ_DATASET, INPUT_BQ_TABLE)\n",
    "import_data_response = tables_client.import_data(\n",
    "    dataset=dataset, bigquery_input_uri=dataset_bq_input_uri)\n",
    "import_data_result = import_data_response.result()\n",
    "import_data_result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inspect the datatypes assigned to each column. In this case, the `song` and `artist` should be categorical, not textual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list_column_specs_response = tables_client.list_column_specs(\n",
    "    dataset_display_name=DATASET_DISPLAY_NAME)\n",
    "column_specs = {s.display_name: s for s in list_column_specs_response}\n",
    "\n",
    "def print_column_specs(column_specs):\n",
    "    \"\"\"Parses the given specs and prints each column and column type.\"\"\"\n",
    "    data_types = automl_v1beta1.proto.data_types_pb2\n",
    "    return [(x, data_types.TypeCode.Name(\n",
    "        column_specs[x].data_type.type_code)) for x in column_specs.keys()]\n",
    "\n",
    "print_column_specs(column_specs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Update Dataset params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes, the types AutoML Tables automatically assigns to each column will differ from what was intended. When that happens, we need to update Tables with different types for certain columns.\n",
    "\n",
    "In this case, set the `song` and `artist` column types to `CATEGORY`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for col in [\"song\", \"artist\"]:\n",
    "    tables_client.update_column_spec(dataset_display_name=DATASET_DISPLAY_NAME,\n",
    "                                     column_spec_display_name=col,\n",
    "                                     type_code=\"CATEGORY\")\n",
    "\n",
    "list_column_specs_response = tables_client.list_column_specs(\n",
    "    dataset_display_name=DATASET_DISPLAY_NAME)\n",
    "column_specs = {s.display_name: s for s in list_column_specs_response}\n",
    "print_column_specs(column_specs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not all columns are feature columns. In order to train a model, we need to tell Tables which column should be used as the target variable and, optionally, which column should be used as sample weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tables_client.set_target_column(dataset_display_name=DATASET_DISPLAY_NAME,\n",
    "                                column_spec_display_name=\"label\")\n",
    "\n",
    "tables_client.set_weight_column(dataset_display_name=DATASET_DISPLAY_NAME,\n",
    "                                column_spec_display_name=\"weight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Create a Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the Dataset has been configured correctly, we can tell AutoML Tables to train a new model. The amount of resources spent to train this model can be adjusted using a parameter called `train_budget_milli_node_hours`. As the name implies, this puts a maximum budget on how many resources a training job can use up before exporting a servable model.\n",
    "\n",
    "Even with a budget of 1 node hour (the minimum possible budget), training a model can take several hours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tables_client.create_model(\n",
    "    model_display_name=MODEL_DISPLAY_NAME,\n",
    "    dataset_display_name=DATASET_DISPLAY_NAME,\n",
    "    train_budget_milli_node_hours=MODEL_TRAIN_HOURS * 1000).result()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Model Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because we are optimizing a surrogate problem (predicting the similarity between `(user, song)` pairs) in order to achieve our final objective of producing a list of recommended songs for a user, it's difficult to tell how well the model performs by looking only at the final loss function. Instead, an evaluation metric we can use for our model is `recall@n` for the top `m` most listened to songs for each user. This metric will give us the probability that one of a user's top `m` most listened to songs will appear in the top `n` recommendations we make.\n",
    "\n",
    "In order to get the top recommendations for each user, we need to create a batch job to predict similarity scores between each user and item pair. These similarity scores would then be sorted per user to produce an ordered list of recommended songs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Create an evaluation table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of creating a lookup table for all users, let's just focus on the performance for a few users for this demo. We will focus especially on recommendations for the user `rob`, and demonstrate how the others can be included in an overall evaluation metric for the model. We start by creating a dataset for prediction to feed into the trained model; this is a table of every possible `(user, song)` pair containing the users and corresponding features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "users = [\"rob\", \"fiveofoh\", \"Aerion\"]\n",
    "training_table = \"{}.{}.{}\".format(PROJECT_ID, INPUT_BQ_DATASET, INPUT_BQ_TABLE)\n",
    "query = \"\"\"\n",
    "    WITH user as (\n",
    "      SELECT user,\n",
    "        user_tags0, user_tags1, user_tags2, user_tags3, user_tags4,\n",
    "        user_tags5, user_tags6, user_tags7, user_tags8, user_tags9,\n",
    "        user_tags10, user_tags11, user_tags12, user_tags13, user_tags14,\n",
    "        user_tags15, user_tags16, user_tags17, user_tags18, user_tags19, label\n",
    "      FROM `{0}`\n",
    "      WHERE user in ({1})\n",
    "    )\n",
    "    SELECT ANY_VALUE(a).*, song, ANY_VALUE(artist) as artist,\n",
    "      ANY_VALUE(tags) as tags, ANY_VALUE(albums) as albums\n",
    "    FROM `{0}`, user a\n",
    "    -- Group by user as well as song to keep one row per (user, song) pair.\n",
    "    GROUP BY a.user, song\n",
    "\"\"\".format(training_table, \",\".join([\"\\\"{}\\\"\".format(x) for x in users]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_table = \"{}_example\".format(INPUT_BQ_TABLE)\n",
    "create_table_from_query(query, eval_table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Make predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the prediction table is created, start a batch prediction job. This may take a few minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_bq_input_uri = \"bq://{}.{}.{}\".format(PROJECT_ID, INPUT_BQ_DATASET, eval_table)\n",
    "preds_bq_output_uri = \"bq://{}\".format(PROJECT_ID)\n",
    "response = tables_client.batch_predict(model_display_name=MODEL_DISPLAY_NAME,\n",
    "                                       bigquery_input_uri=preds_bq_input_uri,\n",
    "                                       bigquery_output_uri=preds_bq_output_uri)\n",
    "response.result()\n",
    "output_uri = response.metadata.batch_predict_details.output_info.bigquery_output_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the similarity predictions for `rob`, we can order by the predictions to get a ranked list of songs to recommend to `rob`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "query = \"\"\"\n",
    "    SELECT user, song, tables.score as score, a.label as pred_label,\n",
    "      b.label as true_label\n",
    "    FROM `{}.predictions` a, UNNEST(predicted_label)\n",
    "    LEFT JOIN `{}` b USING(user, song)\n",
    "    WHERE user = \"{}\" AND CAST(tables.value AS INT64) = 1\n",
    "    ORDER BY score DESC\n",
    "    LIMIT {}\n",
    "\"\"\".format(output_uri[5:].replace(\":\", \".\"), training_table, users[0], n)\n",
    "query_job = bq_client.query(query)\n",
    "\n",
    "print(\"Top {} songs recommended for {}:\".format(n, users[0]))\n",
    "for idx, row in enumerate(query_job):\n",
    "    print(\"{}.\".format(idx + 1), row[\"song\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 Evaluate predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Precision@k and Recall@k\n",
    "\n",
    "To evaluate the recommendations, we can look at the precision@k and recall@k of our predictions for `rob`. Run the cells below to load the recommendations into a pandas dataframe and plot the precisions and recalls at various top-k recommendations. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"\"\"\n",
    "    WITH \n",
    "      top_k AS (\n",
    "        SELECT user, song, label,\n",
    "          ROW_NUMBER() OVER (PARTITION BY user ORDER BY label + weight DESC) as user_rank\n",
    "        FROM `{0}`\n",
    "      )\n",
    "    SELECT user, song, tables.score as score, b.label,\n",
    "      ROW_NUMBER() OVER (ORDER BY tables.score DESC) as rank, user_rank\n",
    "    FROM `{1}.predictions` a, UNNEST(predicted_label)\n",
    "    LEFT JOIN top_k b USING(user, song)\n",
    "    WHERE CAST(tables.value AS INT64) = 1\n",
    "    ORDER BY score DESC\n",
    "\"\"\".format(training_table, output_uri[5:].replace(\":\", \".\"))\n",
    "\n",
    "df = bq_client.query(query).result().to_dataframe()\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "precision_at_k = {}\n",
    "recall_at_k = {}\n",
    "\n",
    "for user in users:\n",
    "    # Rank this user's predictions by score. Pairs missing from the\n",
    "    # training table have no label and are treated as not relevant.\n",
    "    user_df = df[df[\"user\"] == user].sort_values(\"score\", ascending=False)\n",
    "    labels = user_df[\"label\"].fillna(0)\n",
    "    precision_at_k[user] = []\n",
    "    recall_at_k[user] = []\n",
    "    for k in range(1, 1000):\n",
    "        precision = labels.iloc[:k].sum() / k\n",
    "        recall = labels.iloc[:k].sum() / labels.sum()\n",
    "        precision_at_k[user].append(precision)\n",
    "        recall_at_k[user].append(recall)\n",
    "\n",
    "# plot the precision-recall curve\n",
    "ax = sns.lineplot(x=recall_at_k[users[0]], y=precision_at_k[users[0]])\n",
    "ax.set_title(\"precision-recall curve for varying k\")\n",
    "ax.set_xlabel(\"recall@k\")\n",
    "ax.set_ylabel(\"precision@k\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Achieving a high precision@k means a large proportion of top-k recommended items are relevant to the user. Recall@k shows what proportion of all relevant items appeared in the top-k recommendations."
   ]
  },
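  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of these two definitions, suppose the ranked recommendations for one user carry the relevance labels `[1, 0, 1, 1, 0]`. This is a sketch with made-up labels, not output from the model, and `precision_recall_at_k` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def precision_recall_at_k(ranked_labels, k):\n",
    "    # ranked_labels holds 1 if the item at that rank is relevant, else 0.\n",
    "    hits = sum(ranked_labels[:k])\n",
    "    total_relevant = sum(ranked_labels)\n",
    "    return hits / k, hits / total_relevant\n",
    "\n",
    "\n",
    "ranked_labels = [1, 0, 1, 1, 0]  # made-up labels, best-scored item first\n",
    "print(precision_recall_at_k(ranked_labels, 2))  # precision 1/2, recall 1/3\n",
    "print(precision_recall_at_k(ranked_labels, 4))  # (0.75, 1.0)\n",
    "```\n",
    "\n",
    "Raising k from 2 to 4 recovers all three relevant items, so recall reaches 1.0, while precision only rises to 0.75 because one of the four recommendations is not relevant."
   ]
  },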
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Mean Average Precision (MAP)\n",
    "\n",
    "Precision@k is a good metric for understanding how many relevant recommendations we might make at each top-k. However, we would prefer relevant items to be recommended first when possible and should encode that into our evaluation metric. __Average Precision (AP)__ is a running average of precision@k, rewarding recommendations where the relevant items are seen earlier rather than later. When averaged across all users for some k, the AP metric is called MAP."
   ]
  },
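  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the running average concrete, here is a minimal sketch on made-up relevance labels for a single user. The helper names are hypothetical, and the sketch follows the running-average definition given above rather than other AP variants.\n",
    "\n",
    "```python\n",
    "def precision_at_each_k(ranked_labels):\n",
    "    # precision@k for k = 1..len(ranked_labels).\n",
    "    hits, out = 0, []\n",
    "    for k, label in enumerate(ranked_labels, start=1):\n",
    "        hits += label\n",
    "        out.append(hits / k)\n",
    "    return out\n",
    "\n",
    "\n",
    "def running_average(values):\n",
    "    # Entry i is the mean of values[0] through values[i].\n",
    "    total, out = 0.0, []\n",
    "    for count, value in enumerate(values, start=1):\n",
    "        total += value\n",
    "        out.append(total / count)\n",
    "    return out\n",
    "\n",
    "\n",
    "# Two made-up rankings with the same two relevant items in the top 4.\n",
    "early = running_average(precision_at_each_k([1, 1, 0, 0]))\n",
    "late = running_average(precision_at_each_k([0, 0, 1, 1]))\n",
    "print(round(early[-1], 3), round(late[-1], 3))  # 0.792 0.208\n",
    "```\n",
    "\n",
    "Both rankings have the same precision@4, but the one that places the relevant songs first scores a much higher AP@4."
   ]
  },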
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_ap(precision):\n",
    "    ap = [precision[0]]\n",
    "    for p in precision[1:]:\n",
    "        ap.append(ap[-1] + p)\n",
    "    ap = [x / (n + 1) for x, n in zip(ap, range(len(ap)))]\n",
    "    return ap\n",
    "\n",
    "ap_at_k = {user: calculate_ap(pk)\n",
    "           for user, pk in precision_at_k.items()}\n",
    "\n",
    "num_k = 500\n",
    "map_at_k = [sum([ap_at_k[user][k] for user in users]) / len(users)\n",
    "            for k in range(num_k)]\n",
    "print(\"MAP@50: {}\".format(map_at_k[49]))\n",
    "\n",
    "# plot average precision\n",
    "ax = sns.lineplot(x=list(range(num_k)), y=map_at_k)\n",
    "ax.set_title(\"MAP@k for varying k\")\n",
    "ax.set_xlabel(\"k\")\n",
    "ax.set_ylabel(\"MAP\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Cleanup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cells clean up the BigQuery tables and AutoML Tables datasets that were created with this notebook to avoid additional charges for storage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 Delete the Model and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tables_client.delete_model(model_display_name=MODEL_DISPLAY_NAME)\n",
    "\n",
    "tables_client.delete_dataset(dataset_display_name=DATASET_DISPLAY_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Delete BigQuery datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to delete BigQuery tables, make sure the service account linked to this notebook has a role with the `bigquery.tables.delete` permission such as `BigQuery Data Owner`. The following command displays the current service account.\n",
    "\n",
    "IAM permissions can be adjusted [here](https://console.cloud.google.com/iam-admin/iam)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud config list account --format \"value(core.account)\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clean up the BigQuery tables created by this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the prediction dataset.\n",
    "dataset_id = str(output_uri[5:].replace(\":\", \".\"))\n",
    "bq_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)\n",
    "\n",
    "# Delete the training dataset.\n",
    "dataset_id = \"{0}.{1}\".format(PROJECT_ID, INPUT_BQ_DATASET)\n",
    "bq_client.delete_dataset(dataset_id, delete_contents=True, not_found_ok=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
